# This file is a sketch of what a full BrainScript (BS) integration for all of CNTK would look like.
# This used to work, but is not maintained presently.

# The following values are meant to be defined outside this file. Passing them in is
# currently not implemented (we need the "with" operator), so they are hard-coded here.
stderr='..\RunDir\LSTM\Truncated\models\cntkSpeech.dnn.log'     # repeats the RunDir value literally instead of referencing RunDir
RunDir='..\RunDir\LSTM\Truncated'                               # prefix of speechTrain.modelPath
NdlDir='..\LSTM'                                                # not referenced in the visible part of this file
DataDir='.'                                                     # prefix of the reader's scp/mlf/state-list paths in speechTrain
DeviceId=0          // 'auto' does not work here: too many places just pass deviceId as a numeric value. TODO: add a BS function.
makeMode=false


# Global settings. Nothing in this file reads them back; presumably they are consumed
# by the hosting CNTK runtime -- TODO confirm.
precision = 'float'
actions = speechTrain     // the TrainAction record defined below
deviceId = DeviceId       // forwards the externally defined value above

parallelTrain = false

frameMode = false
truncated = true

# Training action selected by 'actions = speechTrain' at the top level. It bundles the
# model output path, the SGD settings, the data reader, and the network definition.
speechTrain = new TrainAction [
    modelPath = RunDir + "/models/cntkSpeech.dnn"    // note: RunDir is written with '\' separators, this suffix with '/'
    traceLevel = 1
    
    // optimizer settings
    optimizer = new SGDOptimizer [
        epochSize = 20480
        minibatchSize = 20
        learningRatesPerMB = 0.5
        numMBsToShowResult = 10
        momentumPerMB = 0:0.9            // two ':'-separated values; presumably a schedule (0 first, then 0.9) -- TODO confirm
        maxEpochs = 4
        keepCheckPointFiles = true       
    ]

    // data reader
    // The dimensions below repeat the ones in createNetwork() (features.dim = 363 = featDim,
    // labels.labelDim = 132 = labelDim); they are written out separately, not derived from them.
    reader = new DataReaderPlugin [
      readerType = 'HTKMLFReader'
      readMethod = 'blockRandomize'
      minibatchMode = 'partial'
      nbruttsineachrecurrentiter = 32
      randomize = 'auto'
      verbosity = 0
      // input stream; same name as the 'features' Input node in createNetwork()
      features = [
          dim = 363
          type = 'real'
          scpFile = DataDir + "/glob_0000.scp"
      ]
  
      // label stream; same name as the 'labels' Input node in createNetwork()
      labels = [
          mlfFile = DataDir + "/glob_0000.mlf"
          labelMappingFile = DataDir + "/state.list"
        
          labelDim = 132
          labelType = 'category'
      ]
    ]

    # define network using BrainScript
    # The network is a stack of numLSTMs LSTM layers followed by an affine output layer,
    # with a cross-entropy criterion, an error-rate evaluation node, and a decoding output.
    createNetwork() = new ComputationNetwork [
        
        // parameter-creating helpers; each helper body is a Parameter(...) expression
        // note: WeightParam passes the same randomSeed=1 on every use
        WeightParam(m,n) = Parameter(m, n, init='uniform', initValueScale=1, initOnCPUOnly=true, randomSeed=1)
        BiasParam(m) = Parameter(m, 1, init='fixedValue', value=0.0)
        ScalarParam() = Parameter(1, 1, init='fixedValue', value=0.0)

        // self-stabilizer: scales its input by Exp() of a scalar parameter that is initialized to 0.0
        // (presumably a scale factor of exp(0) = 1 at the start of training -- confirm against Scale/Exp semantics)
        NewBeta() = Exp(ScalarParam())
        Stabilize(in) = Scale(NewBeta(), in)

        // One LSTM layer with a projected output, in which every weighted input first goes through Stabilize().
        //   inputDim  - second dimension of the input-to-hidden weights W
        //   outputDim - dimension of the projected output; also the second dimension of the hidden-to-hidden weights H
        //   cellDim   - dimension of the gates and the cell; first dimension of W and H
        //   inputx    - the layer's input
        // Evaluates to a record; the layer stack below reads its 'output' member.
        LSTMPComponentWithSelfStab(inputDim, outputDim, cellDim, inputx) =
        [
            // parameter macros--these carry their own weight matrices
            B() = BiasParam(cellDim)
            Wmr = WeightParam(outputDim, cellDim);                  // projection matrix, used once in 'output' below

            W(v) = WeightParam(cellDim, inputDim) * Stabilize(v)    // input-to-hidden
            H(h) = WeightParam(cellDim, outputDim) * Stabilize(h)   // hidden-to-hidden
            C(c) = DiagTimes(WeightParam(cellDim, 1), Stabilize(c)) // cell-to-hidden

            // LSTM cell
            // 'output' and 'ct' are defined further down in this record; PastValue closes the recurrent loop
            dh = PastValue(outputDim, output);                   // hidden state(t-1)
            dc = PastValue(cellDim, ct);                         // cell(t-1)

            // note: the W(inputx) here are all different, they all come with their own set of weights; same for H(dh), C(dc), and B()
            it = Sigmoid(W(inputx) + B() + H(dh) + C(dc))       // input gate(t)
            bit = it .* Tanh(W(inputx) + (H(dh) + B()))         // applied to tanh of input network

            ft = Sigmoid(W(inputx) + B() + H(dh) + C(dc))       // forget-me-not gate(t)
            bft = ft .* dc                                          // applied to cell(t-1)

            ct = bft + bit                                          // c(t) is sum of both

            // unlike the input and forget gates, the output gate uses the current cell ct, not dc
            ot = Sigmoid(W(inputx) + B() + H(dh) + C(ct))       // output gate(t)
            mt = ot .* Tanh(ct)                                     // applied to tanh(cell(t))

            output = Wmr * Stabilize(mt)                            // projection
        ]

        // define basic I/O
        baseFeatDim = 33
        featDim = 11 * baseFeatDim      // TODO: 363--is this the correct explanation?
        labelDim = 132

        // hidden dimensions
        cellDim = 1024
        hiddenDim = 256
        numLSTMs = 3        // number of hidden LSTM model layers

        // features
        features = Input(featDim, 1, tag='feature')
        labels = Input(labelDim, 1, tag='label')
        // NOTE(review): the arguments suggest this keeps baseFeatDim rows starting at row featDim - baseFeatDim,
        // i.e. the last 33 of the 363 input rows. If the 11 stacked frames are ordered t-5..t+5, that slice is
        // frame t+5, which may be where the "5" comes from -- confirm against RowSlice and the feature layout.
        feashift = RowSlice(featDim - baseFeatDim, baseFeatDim, features);      # shift 5 frames right (x_{t+5} -> x_{t} )  // TODO why 5? Where do I see this?

        featNorm = MeanVarNorm(feashift)

        // define the stack of hidden LSTM layers
        // layer 1 takes the normalized features (baseFeatDim); every later layer takes the previous layer's 'output' (hiddenDim)
        LSTMoutput[k:1..numLSTMs] = if k == 1
                                    then LSTMPComponentWithSelfStab(baseFeatDim, hiddenDim, cellDim, featNorm)
                                    else LSTMPComponentWithSelfStab(hiddenDim,   hiddenDim, cellDim, LSTMoutput[k-1].output)

        // and add a softmax layer on top
        // W and B here are defined at network level; the members of the same name inside
        // LSTMPComponentWithSelfStab are separate definitions
        W(in) = WeightParam(labelDim, hiddenDim) * Stabilize(in)
        B = BiasParam(labelDim)
        
        LSTMoutputW = W(LSTMoutput[numLSTMs].output) + B;          // output of the top layer, before the softmax

        // training
        cr = CrossEntropyWithSoftmax(labels, LSTMoutputW, tag='criterion')  // this is the objective
        Err = ClassificationError(labels, LSTMoutputW, tag='evaluation')              // this also gets tracked

        // decoding
        logPrior = LogPrior(labels)
        ScaledLogLikelihood = Minus(LSTMoutputW, logPrior, tag='output')    // sadly we can't say x - y since we want to assign a tag
    ]
]
